from pathlib import Path
import efel
import matplotlib.pyplot as plt
from datareuse import Reuse
import numpy as np
import sys
import pandas as pd
from bluepyparallel import evaluate
from matplotlib.backends.backend_pdf import PdfPages
import extra_features

import pyabf


def abf_reader(in_data):
    """Read every sweep of an ABF file (Broad Institute, Guoping Feng's lab).

    Args:
        in_data: path of the .abf file, handed unchanged to ``pyabf.ABF``.

    Returns:
        list of dict: one entry per sweep, in sweep order, with keys
        ``voltage`` (``sweepY``), ``current`` (``sweepC``), ``time``
        (``sweepX``), ``dt`` (``1.0 / dataRate``) and ``protocol_name``
        (``protocol``). The three signals are copied into new NumPy arrays.
    """
    recording = pyabf.ABF(in_data)

    def _snapshot(sweep_number):
        # setSweep() selects which sweep the sweep* attributes refer to.
        recording.setSweep(sweep_number)
        return {
            "voltage": np.array(recording.sweepY),
            "current": np.array(recording.sweepC),
            "time": np.array(recording.sweepX),
            "dt": 1.0 / recording.dataRate,
            "protocol_name": recording.protocol,
        }

    return [_snapshot(number) for number in range(recording.sweepCount)]


# Three-letter month abbreviation -> zero-padded two-digit month number
# ("Jan" -> "01", ..., "Dec" -> "12"), used to build .abf file names.
month_map = {
    abbreviation: f"{number:02d}"
    for number, abbreviation in enumerate(
        ("Jan", "Feb", "Mar", "Apr", "May", "Jun",
         "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"),
        start=1,
    )
}


def get_data_df(data_path="../../ephys_data"):
    """Get the list of traces to extract features.

    Recordings come from the cell metadata in the ``Spp_files_per_cell``,
    ``Ecel_files_per_cell`` and ``Burst_runaway_files_per_cell`` modules.
    Each metadata object is walked as ``month -> day -> cell -> data``. Every
    id in ``data["files_ids"]`` is looked up on disk as
    ``<data_path>/<month>/<day>/<year>_<MM>_<DD>_<id:04d>.abf`` for the years
    2020 and 2021. Only existing files are kept, and an id found under both
    years yields two rows.

    Args:
        data_path (str or Path): root folder of the ephys recordings. The
            default is the previously hard-coded location, which is resolved
            relative to the current working directory.

    Returns:
        pandas.DataFrame: one row per existing file, with columns ``month``,
        ``day``, ``cell``, ``rec_id``, ``path`` (as str) and ``cell_type``
        (``"spp"``, ``"ecel"`` or ``"runaway"``).

    Side effect: creates a ``traces`` folder in the current working directory.
    """
    # Project-local metadata modules, imported lazily as in the original.
    from Spp_files_per_cell import Spp_cells_metadata
    from Ecel_files_per_cell import Ecel_cells_metadata
    from Burst_runaway_files_per_cell import Burst_runaway_cells_metadata

    data_path = Path(data_path)
    Path("traces").mkdir(exist_ok=True)

    all_data = {"month": [], "day": [], "cell": [], "rec_id": [], "path": [], "cell_type": []}
    for cell_type, type_data in zip(
        ["spp", "ecel", "runaway"],
        [Spp_cells_metadata, Ecel_cells_metadata, Burst_runaway_cells_metadata],
    ):
        for month, month_data in type_data.items():
            _month = month_map[month]
            for day, day_data in month_data.items():
                day_dir = data_path / month / day
                for cell, data in day_data.items():
                    for i in data["files_ids"]:
                        # The year is not part of the metadata, so both
                        # candidate years are probed.
                        for year in ["2020", "2021"]:
                            path = day_dir / f"{year}_{_month}_{int(day):02d}_{i:04d}.abf"
                            if path.exists():
                                all_data["month"].append(month)
                                all_data["day"].append(day)
                                all_data["cell"].append(cell)
                                all_data["rec_id"].append(i)
                                all_data["path"].append(str(path))
                                all_data["cell_type"].append(cell_type)

    return pd.DataFrame.from_dict(all_data)


# Fixed end (ms) of the analysis window; used instead of time[-1] so that
# longer traces are cut at the same point.
STIM_END = 15000


def plot_trace(time, voltage, filename, t_max=STIM_END):
    """Plot a voltage trace (ms, mV) over [0, t_max] and save it to ``filename``."""
    fig = plt.figure(figsize=(10, 5))
    plt.plot(time, voltage)
    plt.xlabel("time (ms)")
    plt.ylabel("voltage (mV)")
    plt.gca().set_xlim(0, t_max)
    plt.savefig(filename)
    plt.close(fig)


def get_stim_start(time, current, threshold=-1):
    """Return the last time at which ``current`` is <= ``threshold``.

    Raises:
        ValueError: if no sample satisfies the condition (the original
            indexing would have raised an opaque IndexError).
    """
    mask = np.asarray(current) <= threshold
    if not mask.any():
        raise ValueError(f"no sample with current <= {threshold}")
    return np.asarray(time)[mask][-1]


def get_spike_isis(peak_times, t_start, t_end):
    """Keep the peaks strictly inside (t_start, t_end) and return them with their ISIs.

    ``peak_times`` may be None (e.g. efel found no peaks). In that case two
    empty arrays are returned.
    """
    if peak_times is None:
        peak_times = []
    peak_times = np.asarray(peak_times, dtype=float)
    peak_times = peak_times[(peak_times > t_start) & (peak_times < t_end)]
    return peak_times, np.diff(peak_times)


def find_tonic_isis(time, voltage, peak_times, isis, thresh, time_shift=5):
    """Return the ISIs classified as tonic firing.

    An ISI is tonic when both it and the preceding ISI exceed ``thresh``, and
    the voltage minimum within the interval occurs in its first
    ``time_shift`` ms. Concretely, the minimum over the whole interval is
    lower than the minimum after ``time_shift``.

    Assumes ``isis == np.diff(peak_times)``, so that ``peak_times[i + 1]``
    exists for every ISI index ``i``.
    """
    time = np.asarray(time)
    voltage = np.asarray(voltage)
    isis_tonic = []
    for i, isi in enumerate(isis):
        # i == 0 has no preceding ISI: isis[i - 1] would wrap to the last one.
        if i == 0 or not (isi > thresh and isis[i - 1] > thresh):
            continue
        start, end = peak_times[i], peak_times[i + 1]
        window = voltage[(time > start) & (time < end)]
        late_window = voltage[(time > start + time_shift) & (time < end)]
        if window.size == 0 or late_window.size == 0:
            continue  # min() of an empty selection is undefined
        if window.min() < late_window.min():
            isis_tonic.append(isi)
    return isis_tonic


def plot_isi_hist(isis, thresh, filename, isis_tonic=None, hist_range=None):
    """Save a histogram of ``isis`` with the burst threshold as a red line.

    When ``isis_tonic`` is given, it is overlaid on the same bins.
    """
    fig = plt.figure(figsize=(5, 3))
    plt.hist(isis, bins=50, range=hist_range)
    if isis_tonic is not None:
        plt.hist(isis_tonic, bins=50, range=hist_range)
    plt.axvline(thresh, color="r")
    plt.xlabel("ISIs (ms)")
    plt.ylabel("# ISIs")
    # tight_layout must run after the labels exist to make room for them.
    plt.tight_layout()
    plt.savefig(filename)
    plt.close(fig)


if __name__ == "__main__":
    _data_df = get_data_df()
    data_df = _data_df[_data_df["cell_type"] == "runaway"]

    # runaway
    row = data_df.iloc[20]
    print(row)
    data = abf_reader(row["path"])[0]
    time = data["time"] * 1000  # x1000 so that time is in ms (see axis labels)
    voltage = data["voltage"]
    plot_trace(time, voltage, "trace.pdf")
    stim_start = get_stim_start(time, data["current"])
    efel_trace = {
        "T": time,
        "V": voltage,
        "stim_start": [stim_start],
        "stim_end": [STIM_END],  # fixed instead of time[-1] to cut longer traces
    }
    features = efel.get_feature_values([efel_trace], ["all_burst_number", "peak_time"])[0]
    print(features)
    peak_times, isis = get_spike_isis(features["peak_time"], stim_start, STIM_END)
    print(isis)
    thresh = extra_features._get_burst_thresh(isis)
    plot_isi_hist(isis, thresh, "isis.pdf")

    # tonic firing after bursts
    row = _data_df[
        (_data_df.month == "Sep")
        & (_data_df.day == "23")
        & (_data_df.cell == "Cell#2")
        & (_data_df.rec_id == 68)
    ].iloc[0]
    print(row)
    data = abf_reader(row["path"])[0]
    time = data["time"] * 1000
    voltage = data["voltage"]
    plot_trace(time, voltage, "trace_tonic.pdf")
    stim_start = get_stim_start(time, data["current"])
    efel_trace = {
        "T": time,
        "V": voltage,
        "stim_start": [stim_start],
        "stim_end": [STIM_END],  # fixed instead of time[-1] to cut longer traces
    }
    features = efel.get_feature_values(
        [efel_trace], ["all_burst_number", "peak_time", "tonic_after_burst"]
    )[0]
    print(features)
    peak_times, isis = get_spike_isis(features["peak_time"], stim_start, STIM_END)
    print(isis)
    thresh = extra_features._get_burst_thresh(isis)
    isis_tonic = find_tonic_isis(time, voltage, peak_times, isis, thresh)
    n_tonic = len(isis_tonic)
    print("n_tonic:", n_tonic)
    plot_isi_hist(isis, thresh, "isis_tonic.pdf", isis_tonic=isis_tonic, hist_range=(0, 600))

    # burst_runaway
